{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "import json\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sentence1': '蚂蚁借呗等额还款可以换成先息后本吗', 'sentence2': '借呗有先息到期还本吗', 'label': '0'}\n"
     ]
    }
   ],
   "source": [
    "class AFQMC(Dataset):\n",
    "    def __init__(self, data_file):\n",
    "        self.data = self.load_data(data_file)\n",
    "    \n",
    "    def load_data(self, data_file):\n",
    "        Data = {}\n",
    "        with open(data_file, 'rt') as f:\n",
    "            for idx, line in enumerate(f):\n",
    "                sample = json.loads(line.strip())\n",
    "                Data[idx] = sample\n",
    "        return Data\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx]\n",
    "\n",
    "train_data = AFQMC('/Users/mac/Downloads/datasets/afqmc_public/train.json')\n",
    "valid_data = AFQMC('/Users/mac/Downloads/datasets/afqmc_public/dev.json')\n",
    "\n",
    "print(train_data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sentence1': '蚂蚁借呗等额还款可以换成先息后本吗', 'sentence2': '借呗有先息到期还本吗', 'label': '0'}\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import IterableDataset\n",
    "\n",
    "class IterableAFQMC(IterableDataset):\n",
    "    def __init__(self, data_file):\n",
    "        self.data_file = data_file\n",
    "\n",
    "    def __iter__(self):\n",
    "        with open(self.data_file, 'rt') as f:\n",
    "            for line in f:\n",
    "                sample = json.loads(line.strip())\n",
    "                yield sample\n",
    "\n",
    "train_data1 = IterableAFQMC('/Users/mac/Downloads/datasets/afqmc_public/train.json')\n",
    "\n",
    "print(next(iter(train_data1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from transformers import AutoTokenizer\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb1ad4f9c18c4146aa356a1147b85271",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62ce9acfbc764c9d847d1a5b78be0eaf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/324 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0edffe73093d4dc09003aca7ade87e39",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.txt: 0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "52313dc6ecc54852b6594be3512e05a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json: 0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "checkpoint = \"bert-base-chinese\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collote_fn(batch_samples):\n",
    "    batch_sentence_1, batch_sentence_2 = [], []\n",
    "    batch_label = []\n",
    "    for sample in batch_samples:\n",
    "        batch_sentence_1.append(sample['sentence1'])\n",
    "        batch_sentence_2.append(sample['sentence2'])\n",
    "        batch_label.append(int(sample['label']))\n",
    "    X = tokenizer(\n",
    "        batch_sentence_1, \n",
    "        batch_sentence_2, \n",
    "        padding=True, \n",
    "        truncation=True, \n",
    "        return_tensors=\"pt\"\n",
    "    )\n",
    "    y = torch.tensor(batch_label)\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch_X shape: {'input_ids': torch.Size([4, 43]), 'token_type_ids': torch.Size([4, 43]), 'attention_mask': torch.Size([4, 43])}\n",
      "batch_y shape: torch.Size([4])\n",
      "{'input_ids': tensor([[ 101,  955, 1446,  711,  784,  720, 1372, 3300,  955, 3621, 3309, 7361,\n",
      "         1063,  702, 3299, 6821,  671,  702, 6848, 7555,  102,  955, 1446,  955,\n",
      "         3621, 1453, 3309,  711,  784,  720, 5367, 4764,  749,  102,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0],\n",
      "        [ 101,  928, 4500, 1305, 5709, 1446, 3119, 3621, 3302, 1218, 1963,  862,\n",
      "         6369, 5050, 3302, 1218, 6589,  102,  886, 4500,  928, 4500, 1305, 1469,\n",
      "         5709, 1446, 3119, 1357, 4638, 3302, 1218, 6589, 3221, 3119, 4638, 7560,\n",
      "         2145, 4638, 6820, 3221, 2769, 4638,  102],\n",
      "        [ 101, 2769, 3221, 6206, 3389, 5709, 1446,  102, 3221, 5709, 1446, 4638,\n",
      "         7309, 7579, 1506,  102,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0],\n",
      "        [ 101, 5709, 1446, 6572, 1296, 1377,  809, 3241, 6820, 3621, 1126, 1921,\n",
      "          102, 5709, 1446, 1126, 1384, 3221, 6820, 3621, 3189,  102,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n",
      "tensor([1, 0, 0, 0])\n"
     ]
    }
   ],
   "source": [
    "train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collote_fn)\n",
    "valid_dataloader= DataLoader(valid_data, batch_size=4, shuffle=False, collate_fn=collote_fn)\n",
    "\n",
    "batch_X, batch_y = next(iter(train_dataloader))\n",
    "print('batch_X shape:', {k: v.shape for k, v in batch_X.items()})\n",
    "print('batch_y shape:', batch_y.shape)\n",
    "print(batch_X)\n",
    "print(batch_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModel\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cpu device\n"
     ]
    }
   ],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'Using {device} device')\n",
    "\n",
    "bert_encoder = AutoModel.from_pretrained(checkpoint)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BertForPairwiseCLS(\n",
      "  (bert_encoder): BertModel(\n",
      "    (embeddings): BertEmbeddings(\n",
      "      (word_embeddings): Embedding(21128, 768, padding_idx=0)\n",
      "      (position_embeddings): Embedding(512, 768)\n",
      "      (token_type_embeddings): Embedding(2, 768)\n",
      "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "    )\n",
      "    (encoder): BertEncoder(\n",
      "      (layer): ModuleList(\n",
      "        (0-11): 12 x BertLayer(\n",
      "          (attention): BertAttention(\n",
      "            (self): BertSelfAttention(\n",
      "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "            (output): BertSelfOutput(\n",
      "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (intermediate): BertIntermediate(\n",
      "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            (intermediate_act_fn): GELUActivation()\n",
      "          )\n",
      "          (output): BertOutput(\n",
      "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (pooler): BertPooler(\n",
      "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "      (activation): Tanh()\n",
      "    )\n",
      "  )\n",
      "  (dropout): Dropout(p=0.1, inplace=False)\n",
      "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class BertForPairwiseCLS(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(BertForPairwiseCLS, self).__init__()\n",
    "        self.bert_encoder = bert_encoder\n",
    "        self.dropout = nn.Dropout(0.1)\n",
    "        self.classifier = nn.Linear(768, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        bert_output = self.bert_encoder(**x)\n",
    "        #print(\"BertForPairwiseCLS \", bert_output.shape)\n",
    "        cls_vectors = bert_output.last_hidden_state[:, 0, :]\n",
    "        cls_vectors = self.dropout(cls_vectors)\n",
    "        logits = self.classifier(cls_vectors)\n",
    "        return logits\n",
    "\n",
    "model = BertForPairwiseCLS().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForPairwiseCLS were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BertForPairwiseCLS(\n",
      "  (bert): BertModel(\n",
      "    (embeddings): BertEmbeddings(\n",
      "      (word_embeddings): Embedding(21128, 768, padding_idx=0)\n",
      "      (position_embeddings): Embedding(512, 768)\n",
      "      (token_type_embeddings): Embedding(2, 768)\n",
      "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "    )\n",
      "    (encoder): BertEncoder(\n",
      "      (layer): ModuleList(\n",
      "        (0-11): 12 x BertLayer(\n",
      "          (attention): BertAttention(\n",
      "            (self): BertSelfAttention(\n",
      "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "            (output): BertSelfOutput(\n",
      "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (intermediate): BertIntermediate(\n",
      "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            (intermediate_act_fn): GELUActivation()\n",
      "          )\n",
      "          (output): BertOutput(\n",
      "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (dropout): Dropout(p=0.1, inplace=False)\n",
      "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoConfig\n",
    "from transformers import BertPreTrainedModel, BertModel\n",
    "\n",
    "class BertForPairwiseCLS(BertPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.bert = BertModel(config, add_pooling_layer=False)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "        self.classifier = nn.Linear(768, 2)\n",
    "        self.post_init()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        bert_output = self.bert(**x)\n",
    "        cls_vectors = bert_output.last_hidden_state[:, 0, :]\n",
    "        cls_vectors = self.dropout(cls_vectors)\n",
    "        logits = self.classifier(cls_vectors)\n",
    "        return logits\n",
    "\n",
    "config = AutoConfig.from_pretrained(checkpoint)\n",
    "model = BertForPairwiseCLS.from_pretrained(checkpoint, config=config).to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': tensor([[ 101,  955, 1446, 1377,  809, 6820, 3621, 3189, 3241,  677, 6820, 1408,\n",
      "          102, 1377,  809, 2454, 6826, 6010, 6009,  955, 1446, 6820, 3621, 3189,\n",
      "         1408,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0],\n",
      "        [ 101, 2769, 1762, 3905, 2140, 4500, 5709, 1446,  743,  691, 6205, 8024,\n",
      "         4385, 1762, 4509, 6435, 6842, 3621, 8024, 1297, 2157, 2768, 1216, 6842,\n",
      "         3621,  749, 8024, 1377, 3221, 2769, 5709, 1446, 6820, 3621, 7032, 7583,\n",
      "          679, 1359, 8024, 2582,  720, 1726,  752,  102, 2769, 3905, 2140,  677,\n",
      "         4500, 5709, 1446,  743, 4638,  691, 6205, 8024, 4197, 1400, 6842, 3621,\n",
      "         1168, 5709, 1446, 8024,  711,  784,  720, 6821,  702, 3299, 2769, 6820,\n",
      "         6206, 6820, 6821,  702, 7178,  102],\n",
      "        [ 101, 4509, 6435,  955, 1446, 6432, 6572, 2787, 3300, 7599, 7372,  102,\n",
      "          955, 1446, 6572, 2787, 6822, 1057, 2128, 1059, 7599, 7372, 2347, 5307,\n",
      "         1282, 1921,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0],\n",
      "        [ 101, 3221, 2458, 6858, 5709, 1446,  102,  711,  784,  720, 2769, 2458,\n",
      "         6858, 5709, 1446, 2218, 3221,  782, 3698, 4255, 1355, 4924, 1400, 1086,\n",
      "         6407, 8024, 3221,  784,  720, 2692, 2590,  102,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "            0,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0]])}\n",
      "torch.Size([4, 2]) tensor([[ 0.0967, -0.9930],\n",
      "        [ 0.0382, -0.9451],\n",
      "        [ 0.4153, -0.8910],\n",
      "        [ 0.2418, -1.0013]], grad_fn=<AddmmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(batch_X)\n",
    "outputs = model(batch_X)\n",
    "print(outputs.shape, outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss):\n",
    "    progress_bar = tqdm(range(len(dataloader)))\n",
    "    progress_bar.set_description(f'loss: {0:>7f}')\n",
    "    finish_step_num = (epoch-1)*len(dataloader)\n",
    "    \n",
    "    model.train()\n",
    "    for step, (X, y) in enumerate(dataloader, start=1):\n",
    "        X, y = X.to(device), y.to(device)\n",
    "        pred = model(X)\n",
    "        loss = loss_fn(pred, y)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}')\n",
    "        progress_bar.update(1)\n",
    "    return total_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_loop(dataloader, model, mode='Test'):\n",
    "    assert mode in ['Valid', 'Test']\n",
    "    size = len(dataloader.dataset)\n",
    "    correct = 0\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for X, y in dataloader:\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            pred = model(X)\n",
    "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n",
    "\n",
    "    correct /= size\n",
    "    print(f\"{mode} Accuracy: {(100*correct):>0.1f}%\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AdamW\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=5e-5)\n",
    "loss_fn = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25752\n"
     ]
    }
   ],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "learning_rate = 1e-5\n",
    "epoch_num = 3\n",
    "num_training_steps = epoch_num * len(train_dataloader)\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps,\n",
    ")\n",
    "print(num_training_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "-------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5fa7b71e5344ef3a86fa749ec9ddac2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8584 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[29], line 21\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(epoch_num):\n\u001b[1;32m     20\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mt\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch_num\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m-------------------------------\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 21\u001b[0m     total_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloss_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr_scheduler\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtotal_loss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     22\u001b[0m     test_loop(valid_dataloader, model, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mValid\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     23\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDone!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "Cell \u001b[0;32mIn[21], line 9\u001b[0m, in \u001b[0;36mtrain_loop\u001b[0;34m(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss)\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step, (X, y) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(dataloader, start\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m):\n\u001b[1;32m      8\u001b[0m     X, y \u001b[38;5;241m=\u001b[39m X\u001b[38;5;241m.\u001b[39mto(device), y\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m----> 9\u001b[0m     pred \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     10\u001b[0m     loss \u001b[38;5;241m=\u001b[39m loss_fn(pred, y)\n\u001b[1;32m     12\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[28], line 9\u001b[0m, in \u001b[0;36mBertForPairwiseCLS.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m----> 9\u001b[0m     bert_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbert_encoder\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;66;03m#print(\"BertForPairwiseCLS \", bert_output.shape)\u001b[39;00m\n\u001b[1;32m     11\u001b[0m     cls_vectors \u001b[38;5;241m=\u001b[39m bert_output\u001b[38;5;241m.\u001b[39mlast_hidden_state[:, \u001b[38;5;241m0\u001b[39m, :]\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/transformers/models/bert/modeling_bert.py:988\u001b[0m, in \u001b[0;36mBertModel.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    979\u001b[0m head_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_head_mask(head_mask, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mnum_hidden_layers)\n\u001b[1;32m    981\u001b[0m embedding_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(\n\u001b[1;32m    982\u001b[0m     input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m    983\u001b[0m     position_ids\u001b[38;5;241m=\u001b[39mposition_ids,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    986\u001b[0m     past_key_values_length\u001b[38;5;241m=\u001b[39mpast_key_values_length,\n\u001b[1;32m    987\u001b[0m )\n\u001b[0;32m--> 988\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    989\u001b[0m \u001b[43m    \u001b[49m\u001b[43membedding_output\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    990\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    991\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    992\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    
993\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_extended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    994\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    995\u001b[0m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    996\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    997\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    998\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    999\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1000\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m encoder_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m   1001\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler(sequence_output) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/transformers/models/bert/modeling_bert.py:582\u001b[0m, in \u001b[0;36mBertEncoder.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    571\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m    572\u001b[0m         layer_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m    573\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    579\u001b[0m         output_attentions,\n\u001b[1;32m    580\u001b[0m     )\n\u001b[1;32m    581\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 582\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mlayer_module\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    583\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    584\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    585\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    586\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    587\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    588\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    589\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    590\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    592\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m 
layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    593\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/transformers/models/bert/modeling_bert.py:514\u001b[0m, in \u001b[0;36mBertLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    511\u001b[0m     cross_attn_present_key_value \u001b[38;5;241m=\u001b[39m cross_attention_outputs[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m    512\u001b[0m     present_key_value \u001b[38;5;241m=\u001b[39m present_key_value \u001b[38;5;241m+\u001b[39m cross_attn_present_key_value\n\u001b[0;32m--> 514\u001b[0m layer_output \u001b[38;5;241m=\u001b[39m \u001b[43mapply_chunking_to_forward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    515\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfeed_forward_chunk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mchunk_size_feed_forward\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mseq_len_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_output\u001b[49m\n\u001b[1;32m    516\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    517\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (layer_output,) \u001b[38;5;241m+\u001b[39m outputs\n\u001b[1;32m    519\u001b[0m \u001b[38;5;66;03m# if decoder, return the attn key/values as the last output\u001b[39;00m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/transformers/pytorch_utils.py:237\u001b[0m, in \u001b[0;36mapply_chunking_to_forward\u001b[0;34m(forward_fn, chunk_size, chunk_dim, *input_tensors)\u001b[0m\n\u001b[1;32m    234\u001b[0m     \u001b[38;5;66;03m# concatenate output at same dimension\u001b[39;00m\n\u001b[1;32m    235\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcat(output_chunks, dim\u001b[38;5;241m=\u001b[39mchunk_dim)\n\u001b[0;32m--> 237\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minput_tensors\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/transformers/models/bert/modeling_bert.py:526\u001b[0m, in \u001b[0;36mBertLayer.feed_forward_chunk\u001b[0;34m(self, attention_output)\u001b[0m\n\u001b[1;32m    525\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfeed_forward_chunk\u001b[39m(\u001b[38;5;28mself\u001b[39m, attention_output):\n\u001b[0;32m--> 526\u001b[0m     intermediate_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mintermediate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattention_output\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    527\u001b[0m     layer_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput(intermediate_output, attention_output)\n\u001b[1;32m    528\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m layer_output\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/transformers/models/bert/modeling_bert.py:426\u001b[0m, in \u001b[0;36mBertIntermediate.forward\u001b[0;34m(self, hidden_states)\u001b[0m\n\u001b[1;32m    425\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[0;32m--> 426\u001b[0m     hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdense\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    427\u001b[0m     hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mintermediate_act_fn(hidden_states)\n\u001b[1;32m    428\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m hidden_states\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/linear.py:116\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    115\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 116\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import random\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "def seed_everything(seed=1029):\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    # some cudnn methods can be random even after fixing the seed\n",
    "    # unless you tell it to be deterministic\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "seed_everything(42)\n",
    "\n",
    "# total_loss is handed to train_loop and replaced by its return value, so it\n",
    "# carries over between epochs (presumably a running loss sum -- see train_loop).\n",
    "total_loss = 0.\n",
    "for t in range(epoch_num):\n",
    "    print(f\"Epoch {t+1}/{epoch_num}\\n-------------------------------\")\n",
    "    total_loss = train_loop(train_dataloader, model, loss_fn, optimizer, lr_scheduler, t+1, total_loss)\n",
    "    test_loop(valid_dataloader, model, mode='Valid')\n",
    "# Report completion once, after the last epoch, rather than after every epoch.\n",
    "print(\"Done!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
